import torchvision.transforms as transforms
from torchvision import datasets
import importlib
def get_dataset(args, testset=True):
    """Build a torchvision training dataset and, optionally, its test split.

    Args:
        args: Config object; this function reads ``args.dataset`` (one of
            ``'MNIST'``, ``'CIFAR10'``, ``'CIFAR100'``, ``'ImageNet'``) and
            ``args.dataset_path`` (root directory handed to the torchvision
            dataset constructor).
        testset: If truthy, also build the held-out split (``train=False``
            for MNIST/CIFAR, ``split='val'`` for ImageNet).

    Returns:
        ``(trainset, testset)`` when ``testset`` is truthy, otherwise only the
        training dataset.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    name = args.dataset
    root = args.dataset_path
    test_data = None

    if name in ('MNIST', 'CIFAR10', 'CIFAR100'):
        # ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 then
        # maps them to [-1, 1]. The single-element mean/std relies on
        # torchvision broadcasting it over every channel, so the same
        # transform is used for grayscale MNIST and RGB CIFAR.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # `datasets` is torchvision.datasets; look up the class by name.
        dataset_cls = getattr(datasets, name)
        train_data = dataset_cls(root=root, train=True, transform=transform)
        if testset:
            test_data = dataset_cls(root=root, train=False, transform=transform)
    elif name == 'ImageNet':
        # Commonly used ImageNet per-channel RGB mean / std.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # Training: random crop + flip for augmentation.
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_data = datasets.ImageNet(root=root, split='train',
                                       transform=transform_train)
        if testset:
            # Evaluation: deterministic resize + center crop.
            transform_val = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
            test_data = datasets.ImageNet(root=root, split='val',
                                          transform=transform_val)
    else:
        raise ValueError(
            f"Unsupported dataset {name!r}; expected one of "
            "'MNIST', 'CIFAR10', 'CIFAR100', 'ImageNet'"
        )

    # Decide on the caller's flag, not on the dataset object: a dataset's
    # truthiness is its length, so an empty test split would otherwise
    # silently change the shape of the return value.
    return (train_data, test_data) if testset else train_data


# Example `dataset_path` values:
#   MNIST / CIFAR10 / CIFAR100: "~/KO/datasets"
#   ImageNet: '/home/l6eub2ic/whcs-share31/zhaotong/datasets/imagenet2012'